Latent neural process (LNP): Derivation of an ELBO-like objective

Latent neural processes (LNPs) [1] use a training objective that is inspired by the ELBO. The derivation proceeds through the following steps.

Usual ELBO

Suppose we have a dataset $D = \{(x_i, y_i)\}_{i=1}^{n}$ that has been generated from a latent variable $z$ by a model $p_\theta(y_D \mid z, x_D)$. We could estimate the parameters $\theta$ by maximum likelihood, i.e. by maximizing the marginal likelihood of $y_D = \{y_i\}_{i=1}^{n}$. The log marginal likelihood can be decomposed as

$$\log p(y_D \mid x_D) = \mathrm{KL}\big(q(z \mid x_D, y_D) \,\|\, p(z \mid x_D, y_D)\big) + \mathbb{E}_q\big[\log p(y_D \mid x_D, z)\big] + \mathbb{E}_q\left[\log \frac{p(z)}{q(z \mid x_D, y_D)}\right],$$

which results in the usual ELBO bound

$$\log p(y_D \mid x_D) \geq \mathbb{E}_q\big[\log p(y_D \mid x_D, z)\big] + \mathbb{E}_q\left[\log \frac{p(z)}{q(z \mid x_D, y_D)}\right].$$
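For completeness, here is a quick check of the decomposition above; as in that decomposition, the prior over $z$ is assumed not to depend on the inputs, so that $p(z \mid x_D) = p(z)$:

$$\log p(y_D \mid x_D) = \mathbb{E}_q\!\left[\log \frac{p(y_D, z \mid x_D)}{p(z \mid x_D, y_D)}\right] = \underbrace{\mathbb{E}_q\!\left[\log \frac{p(y_D, z \mid x_D)}{q(z \mid x_D, y_D)}\right]}_{\text{ELBO}} + \underbrace{\mathbb{E}_q\!\left[\log \frac{q(z \mid x_D, y_D)}{p(z \mid x_D, y_D)}\right]}_{\mathrm{KL}},$$

and expanding $p(y_D, z \mid x_D) = p(y_D \mid x_D, z)\, p(z)$ in the ELBO term recovers the two expectations above.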

Note that this is equivalent to the ELBO bound of a VAE for a single datapoint $y := y_D$ originating from a latent variable $z$, with the only difference that here we are conditioning everything on the inputs $x_D$ too (by design).
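To make the VAE analogy concrete, the following is a minimal PyTorch-style sketch of this bound for a single dataset, using one Monte Carlo sample of $z$ and assuming a standard normal prior $p(z)$ and Gaussian likelihoods; `encoder` and `decoder` are hypothetical modules producing the parameters of $q(z \mid x_D, y_D)$ and $p(y_D \mid x_D, z)$, respectively.

```python
import torch
from torch.distributions import Normal, kl_divergence

def elbo(encoder, decoder, x_d, y_d):
    """Usual ELBO for one dataset (x_D, y_D), conditioning on the inputs throughout.

    encoder(x, y) -> (mu, sigma) of q(z | x, y)   # hypothetical module
    decoder(x, z) -> (mu, sigma) of p(y | x, z)   # hypothetical module
    """
    # Variational posterior q(z | x_D, y_D)
    q = Normal(*encoder(x_d, y_d))

    # One reparameterized Monte Carlo sample of z
    z = q.rsample()

    # Reconstruction term E_q[log p(y_D | x_D, z)]
    p_mu, p_sigma = decoder(x_d, z)
    log_lik = Normal(p_mu, p_sigma).log_prob(y_d).sum()

    # Regularization term E_q[log p(z) / q(z | x_D, y_D)] = -KL(q || p(z)),
    # with an assumed standard normal prior p(z)
    prior = Normal(torch.zeros_like(q.mean), torch.ones_like(q.stddev))
    kl = kl_divergence(q, prior).sum()

    return log_lik - kl  # maximize this with respect to the parameters
```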

Conditional ELBO (intractable)

Now, suppose that we partition the dataset into a context set $C = \{(x_i, y_i)\}_{i=1}^{m}$ and a target set $T = \{(x_i, y_i)\}_{i=m+1}^{n}$. The goal is to infer $y_T$ given information from $y_C$. We could try to obtain appropriate parameters $\theta$ by maximizing the same marginal likelihood as before, $p(y_D \mid x_D)$. However, this is an indirect objective, since it amounts to maximizing the likelihood of the entire dataset. What we really want to maximize is the conditional marginal likelihood $p(y_T \mid x_D, y_C)$. Following the same VAE analogy as before, this is equivalent to reconstructing part of a datapoint based on the rest of it (for example, reconstructing the left side of an image from the right side).
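As a concrete illustration of this setup, a hypothetical helper that randomly partitions the $n$ points of one sampled function into context and target sets might look as follows.

```python
import torch

def split_context_target(x, y, num_context):
    """Randomly split the n points of one function into a context set C of size
    m = num_context and a target set T with the remaining n - m points."""
    n = x.shape[0]
    perm = torch.randperm(n)
    ctx, tgt = perm[:num_context], perm[num_context:]
    return (x[ctx], y[ctx]), (x[tgt], y[tgt])
```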

We can obtain an ELBO bound for this conditional marginal likelihood as follows. The LHS of the usual ELBO can be rewritten as

$$\log p(y_D \mid x_D) = \log\big(p(y_T \mid x_D, y_C)\, p(y_C \mid x_D)\big) = \log p(y_T \mid x_D, y_C) + \log p(y_C \mid x_D).$$

Similarly, using the model assumption that $y_T$ and $y_C$ are conditionally independent given $z$, and multiplying and dividing by $p(z \mid x_D, y_C)$ in the second term, the RHS can be rewritten as

$$\mathbb{E}_q\big[\log\big(p(y_T \mid z, x_D)\, p(y_C \mid z, x_D)\big)\big] + \mathbb{E}_q\left[\log \frac{p(z)\, p(z \mid x_D, y_C)}{q(z \mid x_D, y_D)\, p(z \mid x_D, y_C)}\right].$$

We can obtain the desired ELBO by subtracting $\log p(y_C \mid x_D)$ from both sides and regrouping the terms on the RHS. After subtracting this term, the LHS becomes the desired conditional log marginal likelihood, $\log p(y_T \mid x_D, y_C)$. The RHS becomes

$$\mathbb{E}_q\big[\log p(y_T \mid z, x_D)\big] + \mathbb{E}_q\left[\log \frac{p(z \mid x_D, y_C)}{q(z \mid x_D, y_D)}\right] + \mathbb{E}_q\left[\log \frac{p(y_C \mid z, x_D)\, p(z)}{p(z \mid x_D, y_C)} - \log p(y_C \mid x_D)\right],$$

where the last term cancels out:

$$\log \frac{p(y_C \mid z, x_D)\, p(z)}{p(z \mid x_D, y_C)} - \log p(y_C \mid x_D) = \log \frac{p(y_C, z \mid x_D)}{p(z \mid x_D, y_C)\, p(y_C \mid x_D)} = \log \frac{p(y_C, z \mid x_D)}{p(y_C, z \mid x_D)} = \log 1 = 0.$$

Thus, we arrive at

$$\log p(y_T \mid x_D, y_C) \geq \mathbb{E}_q\big[\log p(y_T \mid z, x_D)\big] + \mathbb{E}_q\left[\log \frac{p(z \mid x_D, y_C)}{q(z \mid x_D, y_D)}\right],$$

which becomes the desired bound after making the (reasonable) assumption that each prediction $y_i$ can only be informed by its corresponding input $x_i$ (and $z$), so that $p(y_T \mid z, x_D)$ becomes $p(y_T \mid z, x_T)$ and $p(z \mid x_D, y_C)$ becomes $p(z \mid x_C, y_C)$:

$$\log p(y_T \mid x_D, y_C) \geq \mathbb{E}_q\big[\log p(y_T \mid z, x_T)\big] + \mathbb{E}_q\left[\log \frac{p(z \mid x_C, y_C)}{q(z \mid x_D, y_D)}\right] = \mathbb{E}_q\big[\log p(y_T \mid z, x_T)\big] - \mathrm{KL}\big(q(z \mid x_D, y_D) \,\|\, p(z \mid x_C, y_C)\big).$$

We would be done except for one problem: this expression is unfortunately intractable, because the true posterior $p(z \mid x_C, y_C) = \frac{p(y_C \mid x_C, z)\, p(z)}{\int p(y_C \mid x_C, z)\, p(z)\, dz}$ is intractable.

Note that, since we have subtracted the same term from both the log marginal likelihood and the ELBO, the gap between them is unchanged with respect to the original decomposition, i.e. it is still the KL divergence

$$\mathrm{KL}\big(q(z \mid x_D, y_D) \,\|\, p(z \mid x_D, y_D)\big).$$

Conditional ELBO-like (tractable)

The previous ELBO is a bound on the desired marginal likelihood $p(y_T \mid x_D, y_C)$, but it is intractable because of the true posterior $p(z \mid x_C, y_C)$. LNPs circumvent this issue by approximating this term with the variational distribution $q(z \mid x_C, y_C)$:

$$\log p(y_T \mid x_D, y_C) \gtrsim \mathbb{E}_q\big[\log p(y_T \mid z, x_T)\big] - \mathrm{KL}\big(q(z \mid x_D, y_D) \,\|\, q(z \mid x_C, y_C)\big).$$

The right-hand term (the KL, or regularization, term) can now be interpreted in the following way: the variational distribution over the latent variable $z$ should be the same whether the model has access to full information about the function, $(x_D, y_D)$, or only partial information about it, $(x_C, y_C)$. This seems a reasonable objective for NPs, which try to recover the whole $(x_D, y_D)$ based only on $(x_C, y_C)$.

Note, however, that this ELBO-like objective is no longer an analytical lower bound on the conditional log marginal likelihood $\log p(y_T \mid x_D, y_C)$, so there is no longer any guarantee that maximizing it also maximizes the likelihood of the parameters.
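Putting things together, here is a minimal PyTorch-style sketch of this ELBO-like objective (negated, as a loss to minimize), with a single Monte Carlo sample of $z$ and the same hypothetical `encoder` and `decoder` interfaces as in the earlier sketch; a full LNP would typically also aggregate the points with a permutation-invariant encoder, which is omitted here for simplicity.

```python
import torch
from torch.distributions import Normal, kl_divergence

def lnp_loss(encoder, decoder, x_c, y_c, x_t, y_t):
    """ELBO-like objective of an LNP, negated so it can be minimized.

    encoder(x, y) -> (mu, sigma) of q(z | x, y)   # hypothetical module
    decoder(x, z) -> (mu, sigma) of p(y | x, z)   # hypothetical module
    """
    # Full dataset D = C ∪ T
    x_d = torch.cat([x_c, x_t], dim=0)
    y_d = torch.cat([y_c, y_t], dim=0)

    # Variational distributions with full and with partial information
    q_d = Normal(*encoder(x_d, y_d))  # q(z | x_D, y_D)
    q_c = Normal(*encoder(x_c, y_c))  # q(z | x_C, y_C), approximates p(z | x_C, y_C)

    # Reconstruction term E_q[log p(y_T | z, x_T)], one sample z ~ q(z | x_D, y_D)
    z = q_d.rsample()
    p_mu, p_sigma = decoder(x_t, z)
    log_lik = Normal(p_mu, p_sigma).log_prob(y_t).sum()

    # Regularization term KL(q(z | x_D, y_D) || q(z | x_C, y_C))
    kl = kl_divergence(q_d, q_c).sum()

    return -(log_lik - kl)
```

(At test time only the context is available, so predictions for new target inputs are made by sampling $z$ from $q(z \mid x_C, y_C)$ and decoding.)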


References

[1] Garnelo et al. (2018). Neural processes.

[2] Garnelo et al. (2018). Conditional neural processes.